#!/usr/bin/env python3
"""
Phase 7: Robustness
Tables 20-23: Alternative DVs, commodity controls, time splits, placebo tests.
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats as sp_stats

# Project root. Hard-coded absolute path (looks like a WSL mount of C:) —
# adjust when running on another machine.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
# Put the multilateral project's src/ on the import path so `model.PanelGLS`
# resolves; this insert must stay before the import that depends on it.
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input/output locations for the CCA tipping-point analysis.
CCA_DIR = PROJECT_DIR / "cca_tipping"
PROCESSED_DIR = CCA_DIR / "data" / "processed"
TABLE_DIR = CCA_DIR / "output" / "tables"
# Created up front so every table writer below can write into it directly.
TABLE_DIR.mkdir(parents=True, exist_ok=True)

# ── Country Sets ───────────────────────────────────────────────────────────
# All CCA countries (ISO3 codes); also used to keep CCA out of placebo pools.
CCA_ALL = {'ARM', 'AZE', 'BLR', 'GEO', 'KAZ', 'KGZ', 'MDA', 'MNG', 'RUS',
           'TJK', 'TKM', 'UKR', 'UZB'}
# CCA subset treated as non-commodity ("CCA-nc"); the group whose exclusion
# is tested for fragility throughout this script.
CCA_NON_COMMODITY = {'ARM', 'BLR', 'GEO', 'KGZ', 'MDA', 'MNG', 'TJK', 'UKR'}

# ── Helpers ────────────────────────────────────────────────────────────────
def star(p):
    """Return the conventional significance marker for p-value *p*.

    "***" when p < 0.01, "**" when p < 0.05, "*" when p < 0.10, else "".
    A NaN p-value fails every comparison and therefore yields "".
    """
    for cutoff, marker in ((0.01, "***"), (0.05, "**"), (0.10, "*")):
        if p < cutoff:
            return marker
    return ""

def run_gls(df, dep_var, indep_vars):
    """Fit a PanelGLS regression of *dep_var* on *indep_vars*.

    Rows with a missing value in the dependent variable or in any regressor
    are dropped first. Country and year identifiers are taken from the
    'iso3' and 'year' columns of *df*.

    Returns:
        None when fewer than 50 complete rows remain. Otherwise a dict with
        'r_squared', 'n_obs', 'n_countries', and the per-regressor dicts
        'coefficients', 'std_errors' and 'p_values' (keyed by name in
        *indep_vars*).
    """
    complete = df.dropna(subset=[dep_var] + indep_vars).copy()
    if complete.shape[0] < 50:
        return None

    model = PanelGLS()
    model.fit(
        complete[dep_var].values,
        complete[indep_vars].values,
        complete['iso3'].values,
        complete['year'].values,
    )

    def by_regressor(values):
        # Label positional estimates with their regressor names.
        return dict(zip(indep_vars, values))

    return {
        'r_squared': model.r_squared,
        'n_obs': model.n_obs,
        'n_countries': model.n_countries,
        'coefficients': by_regressor(model.beta),
        'std_errors': by_regressor(model.se),
        'p_values': by_regressor(model.pvalues),
    }

# ── Load Data ──────────────────────────────────────────────────────────────
rule = "=" * 70
print(rule)
print("PHASE 7: ROBUSTNESS")
print(rule)

# Estimation panel: rows with an observed CA/GDP, years 1986–2024 inclusive.
panel = pd.read_csv(PROCESSED_DIR / "cca_panel.csv")
keep = panel['ca_gdp'].notna() & panel['year'].between(1986, 2024)
est = panel.loc[keep].copy()

# Demographic factors and the baseline control set shared by the regressions.
demo_vars = ['Z_1', 'Z_2', 'Z_3']
baseline_controls = [
    'fiscal_bal_gdp', 'kaopen', 'expected_growth',
    'nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp',
]
base_vars = demo_vars + baseline_controls

# Common baseline sample: complete cases on CA/GDP and every baseline regressor.
base_sample = est.dropna(subset=base_vars + ['ca_gdp']).copy()

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 20: ALTERNATIVE DEPENDENT VARIABLES
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 20: ALTERNATIVE DEPENDENT VARIABLES")
print("=" * 70)

# Construct the saving-investment gap only when it is absent and BOTH of its
# components exist. Otherwise leave the column missing, so the availability
# check below reports it instead of estimating on an all-NaN placeholder.
if 'savings_investment_gap' not in est.columns:
    si_components = ('gross_national_savings_gdp', 'gross_investment_gdp')
    if all(col in est.columns for col in si_components):
        est['savings_investment_gap'] = (est['gross_national_savings_gdp']
                                         - est['gross_investment_gdp'])

# NOTE(review): `est` only holds rows with observed ca_gdp, so every DV here is
# estimated on the CA-available sample. base_vars also contains nfa_gdp_lag,
# so the NFA/GDP specification includes a lagged dependent variable.
alt_dvs = [
    ('ca_gdp', 'CA/GDP'),
    ('nfa_gdp', 'NFA/GDP'),
    ('savings_investment_gap', 'S-I gap'),
]

alt_rows = []
for dv, dv_label in alt_dvs:
    if dv not in est.columns:
        print(f"  {dv_label}: variable not available")
        continue

    # (table sample label, console label, fit): full sample, then excluding
    # the CCA non-commodity countries.
    fits = [
        ('Full', 'full', run_gls(est, dv, base_vars)),
        ('Excl CCA-nc', 'excl CCA-nc',
         run_gls(est[~est['iso3'].isin(CCA_NON_COMMODITY)], dv, base_vars)),
    ]
    for sample_label, console_label, r in fits:
        if r is None:
            continue
        row = {'dv': dv_label, 'sample': sample_label, 'n_countries': r['n_countries'],
               'n_obs': r['n_obs'], 'r_squared': r['r_squared']}
        for v in demo_vars:
            row[f'{v}_coef'] = r['coefficients'][v]
            row[f'{v}_pval'] = r['p_values'][v]
        alt_rows.append(row)
        print(f"  {dv_label} ({console_label}): Z₁={r['coefficients']['Z_1']:.2f}{star(r['p_values']['Z_1'])}")

alt_df = pd.DataFrame(alt_rows)
alt_df.to_csv(TABLE_DIR / "table20_alternative_dvs.csv", index=False)

md = ["# Table 20: Alternative Dependent Variables\n"]
md.append("| DV | Sample | N_c | Z₁ | p-val | Z₂ | p-val | R² |")
md.append("|---|---|---|---|---|---|---|---|")
for _, row in alt_df.iterrows():
    md.append(f"| {row['dv']} | {row['sample']} | {row['n_countries']} | "
              f"{row['Z_1_coef']:.2f}{star(row['Z_1_pval'])} | {row['Z_1_pval']:.4f} | "
              f"{row['Z_2_coef']:.3f}{star(row['Z_2_pval'])} | {row['Z_2_pval']:.4f} | "
              f"{row['r_squared']:.4f} |")
(TABLE_DIR / "table20_alternative_dvs.md").write_text("\n".join(md))

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 21: COMMODITY CONTROLS
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 21: COMMODITY CONTROLS")
print("=" * 70)

# Commodity-exporter dummy built from a fixed, hand-curated ISO3 list (no trade
# data are used). Membership is constant over time for each country.
commodity_exporters = {'AZE', 'KAZ', 'RUS', 'TKM', 'UZB',  # CCA oil/gas
                       'SAU', 'ARE', 'KWT', 'QAT', 'OMN', 'BHR',  # Gulf
                       'NGA', 'AGO', 'GAB', 'COG', 'GNQ',  # Africa oil
                       'VEN', 'ECU', 'BOL', 'TTO',  # LatAm
                       'NOR', 'CAN', 'AUS'}  # Advanced commodity

est['is_commodity_exporter'] = est['iso3'].isin(commodity_exporters).astype(float)
est['Z_1_x_commodity'] = est['Z_1'] * est['is_commodity_exporter']

# (label, regressors, drop commodity exporters from the sample?)
# The sample filter is an explicit flag rather than keyed off the label text.
# NOTE(review): is_commodity_exporter does not vary within a country; if
# PanelGLS absorbs country effects it is not separately identified — confirm.
comm_specs = [
    ("Baseline", base_vars, False),
    ("+ Commodity dummy", base_vars + ['is_commodity_exporter'], False),
    ("+ Z₁ × commodity", base_vars + ['is_commodity_exporter', 'Z_1_x_commodity'], False),
    ("Excl all commodity exporters", base_vars, True),
]

# Optional commodity terms: regressor name -> column prefix in the output table.
commodity_terms = (('is_commodity_exporter', 'commodity'),
                   ('Z_1_x_commodity', 'Z1_x_comm'))

comm_rows = []
for spec_name, spec_vars, drop_exporters in comm_specs:
    sub = est[~est['iso3'].isin(commodity_exporters)] if drop_exporters else est
    r = run_gls(sub, 'ca_gdp', spec_vars)
    if r is None:
        continue
    row = {'specification': spec_name, 'n_countries': r['n_countries'],
           'n_obs': r['n_obs'], 'r_squared': r['r_squared']}
    for v in demo_vars:
        row[f'{v}_coef'] = r['coefficients'][v]
        row[f'{v}_pval'] = r['p_values'][v]
    # Report commodity terms only for specifications that include them.
    for term, prefix in commodity_terms:
        if term in r['coefficients']:
            row[f'{prefix}_coef'] = r['coefficients'][term]
            row[f'{prefix}_pval'] = r['p_values'][term]
    comm_rows.append(row)
    print(f"  {spec_name:<35s}: Z₁={r['coefficients']['Z_1']:.2f}{star(r['p_values']['Z_1'])}")

comm_df = pd.DataFrame(comm_rows)
comm_df.to_csv(TABLE_DIR / "table21_commodity_controls.csv", index=False)

md = ["# Table 21: Commodity Controls\n"]
md.append("| Specification | N_c | Z₁ | p-val | R² | Commodity coef | Z₁×Commodity |")
md.append("|---|---|---|---|---|---|---|")
for _, row in comm_df.iterrows():
    # Columns may be absent (or NaN) when no fitted spec included the term.
    cc_val = row.get('commodity_coef', np.nan)
    zc_val = row.get('Z1_x_comm_coef', np.nan)
    cc = f"{cc_val:.2f}" if pd.notna(cc_val) else "—"
    zc = f"{zc_val:.2f}" if pd.notna(zc_val) else "—"
    md.append(f"| {row['specification']} | {row['n_countries']} | "
              f"{row['Z_1_coef']:.2f}{star(row['Z_1_pval'])} | {row['Z_1_pval']:.4f} | "
              f"{row['r_squared']:.4f} | {cc} | {zc} |")
(TABLE_DIR / "table21_commodity_controls.md").write_text("\n".join(md))

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 22: TIME SPLITS (PRE-2008 vs POST-2008)
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 22: TIME SPLITS")
print("=" * 70)

# (label, first year inclusive, last year inclusive); None leaves that end open.
time_splits = [
    ("Full sample", None, None),
    ("Pre-2008", None, 2007),
    ("Post-2008", 2008, None),
    ("1990-2004", 1990, 2004),
    ("2005-2024", 2005, 2024),
]

split_rows = []
for period, first_year, last_year in time_splits:
    window = base_sample.copy()
    if first_year is not None:
        window = window[window['year'] >= first_year]
    if last_year is not None:
        window = window[window['year'] <= last_year]

    # Re-estimate the baseline on this window, with and without CCA-nc.
    window_no_cca_nc = window[~window['iso3'].isin(CCA_NON_COMMODITY)]
    fits = [
        ("Full", run_gls(window, 'ca_gdp', base_vars)),
        ("Excl CCA-nc", run_gls(window_no_cca_nc, 'ca_gdp', base_vars)),
    ]

    for sample_label, fit in fits:
        if fit is None:
            continue
        record = {'period': period, 'sample': sample_label,
                  'n_countries': fit['n_countries'], 'n_obs': fit['n_obs'],
                  'r_squared': fit['r_squared']}
        for var in demo_vars:
            record[f'{var}_coef'] = fit['coefficients'][var]
            record[f'{var}_se'] = fit['std_errors'][var]
            record[f'{var}_pval'] = fit['p_values'][var]
        split_rows.append(record)
        print(f"  {period:<15s} ({sample_label:<12s}): Z₁={record['Z_1_coef']:7.2f}{star(record['Z_1_pval'])} "
              f"(p={record['Z_1_pval']:.4f}), N_c={fit['n_countries']}")

split_df = pd.DataFrame(split_rows)
split_df.to_csv(TABLE_DIR / "table22_time_splits.csv", index=False)

md_lines = [
    "# Table 22: Time Splits\n",
    "| Period | Sample | N_c | N_obs | Z₁ | SE | p-val | R² |",
    "|---|---|---|---|---|---|---|---|",
]
md_lines.extend(
    f"| {rec['period']} | {rec['sample']} | {rec['n_countries']} | {rec['n_obs']} | "
    f"{rec['Z_1_coef']:.2f}{star(rec['Z_1_pval'])} | {rec['Z_1_se']:.2f} | "
    f"{rec['Z_1_pval']:.4f} | {rec['r_squared']:.4f} |"
    for _, rec in split_df.iterrows()
)
(TABLE_DIR / "table22_time_splits.md").write_text("\n".join(md_lines))

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 23: PLACEBO TESTS
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 23: PLACEBO TESTS")
print("=" * 70)

# Placebo design:
# 1. Drop as many randomly chosen, income-matched non-CCA countries as there
#    are CCA non-commodity (CCA-nc) countries in the baseline sample.
# 2. Re-estimate the baseline and record the shift in Z₁.
# 3. Repeat N_BOOT times to get a reference distribution for the shift caused
#    by dropping the actual CCA-nc group.
N_BOOT = 500
FIG_DIR = CCA_DIR / "output" / "figures"

# True effect: Z₁ change when dropping the actual CCA-nc countries
r_full_ref = run_gls(base_sample, 'ca_gdp', base_vars)
sub_excl_true = base_sample[~base_sample['iso3'].isin(CCA_NON_COMMODITY)]
r_excl_true = run_gls(sub_excl_true, 'ca_gdp', base_vars)

if r_full_ref is None or r_excl_true is None:
    # Without both reference fits there is no true Δ to compare against.
    print("  Placebo test skipped: reference regression could not be estimated "
          "(fewer than 50 complete observations)")
else:
    true_z1_full = r_full_ref['coefficients']['Z_1']
    true_z1_excl = r_excl_true['coefficients']['Z_1']
    true_z1_change = true_z1_excl - true_z1_full
    print(f"  True Z₁ change (drop CCA-nc): {true_z1_full:.2f} → {true_z1_excl:.2f} (Δ={true_z1_change:.2f})")

    # Placebo pool: non-CCA countries whose mean GDP/capita lies within
    # [0.3 × min, 3 × max] of the CCA-nc range. Fall back to all non-CCA
    # countries when CCA-nc income data are unavailable.
    non_cca_countries = sorted(set(base_sample['iso3'].unique()) - CCA_ALL)
    if 'gdp_pc_ppp' in base_sample.columns:
        country_gdp = base_sample.groupby('iso3')['gdp_pc_ppp'].mean().dropna()
    else:
        country_gdp = pd.Series(dtype=float)
    cca_nc_gdp = country_gdp[country_gdp.index.isin(CCA_NON_COMMODITY)]
    if len(cca_nc_gdp) > 0:
        gdp_lo = cca_nc_gdp.min() * 0.3
        gdp_hi = cca_nc_gdp.max() * 3.0
        income_matched = country_gdp[(country_gdp >= gdp_lo) & (country_gdp <= gdp_hi) &
                                     (~country_gdp.index.isin(CCA_ALL))].index.tolist()
        pool_desc = "income-matched non-CCA"
        print(f"  Income-matched pool: {len(income_matched)} countries (GDP/c {gdp_lo:.0f}-{gdp_hi:.0f})")
    else:
        income_matched = non_cca_countries
        pool_desc = "non-CCA"
        print(f"  No CCA-nc income data; placebo pool = all {len(income_matched)} non-CCA countries")

    n_to_drop = len(CCA_NON_COMMODITY & set(base_sample['iso3'].unique()))
    # A local RandomState(42) replays the same draws as the former global
    # np.random.seed(42) + np.random.choice, assuming nothing else consumes
    # the global RNG. It leaves NumPy's global state untouched.
    rng = np.random.RandomState(42)

    placebo_z1_changes = []
    if len(income_matched) < n_to_drop:
        # The pool size is fixed, so this is checked once rather than per draw.
        print(f"  Placebo pool too small ({len(income_matched)} < {n_to_drop}); no draws made")
    else:
        for _ in range(N_BOOT):
            fake_cca = rng.choice(income_matched, size=n_to_drop, replace=False)
            r = run_gls(base_sample[~base_sample['iso3'].isin(fake_cca)], 'ca_gdp', base_vars)
            if r is not None:
                placebo_z1_changes.append(r['coefficients']['Z_1'] - true_z1_full)

    if placebo_z1_changes:
        placebo_arr = np.array(placebo_z1_changes)
        # Successful fits only. This can be below N_BOOT when some fits fail.
        n_draws = len(placebo_arr)
        placebo_mean = placebo_arr.mean()
        placebo_std = placebo_arr.std()
        placebo_p05, placebo_p95 = np.percentile(placebo_arr, [5, 95])
        # Two-sided: share of placebo shifts at least as large in magnitude as
        # the true shift.
        # NOTE(review): this can be exactly 0; consider (1 + count) / (1 + n_draws).
        pct_more_extreme = np.mean(np.abs(placebo_arr) >= np.abs(true_z1_change))
        is_anomalous = pct_more_extreme < 0.05

        print(f"\n  Placebo distribution ({n_draws} draws):")
        print(f"    Mean Δ: {placebo_mean:.2f}")
        print(f"    Std Δ:  {placebo_std:.2f}")
        print(f"    True Δ: {true_z1_change:.2f}")
        print(f"    p-value (two-sided): {pct_more_extreme:.4f}")
        print(f"    True effect is {'UNUSUAL' if is_anomalous else 'NOT unusual'} "
              f"relative to placebo distribution")

        placebo_results = {
            'true_z1_change': true_z1_change,
            'placebo_mean': placebo_mean,
            'placebo_std': placebo_std,
            'placebo_p_value': pct_more_extreme,
            'n_bootstrap': N_BOOT,
            'n_valid_draws': n_draws,
            'n_drop': n_to_drop,
            'n_pool': len(income_matched),
            'placebo_p05': placebo_p05,
            'placebo_p95': placebo_p95,
        }
        pd.DataFrame([placebo_results]).to_csv(TABLE_DIR / "table23_placebo.csv", index=False)

        conclusion = ("CCA-nc effect is ANOMALOUS" if is_anomalous else
                      f"CCA-nc effect is NOT anomalous (could arise from random {n_to_drop}-country drop)")
        md = [
            "# Table 23: Placebo Test Results\n",
            f"Procedure: Randomly drop {n_to_drop} {pool_desc} countries ({n_draws} iterations)\n",
            "| Statistic | Value |",
            "|---|---|",
            f"| True Z₁ change (drop CCA-nc) | {true_z1_change:.2f} |",
            f"| Placebo mean Δ | {placebo_mean:.2f} |",
            f"| Placebo std Δ | {placebo_std:.2f} |",
            f"| Placebo 5th percentile | {placebo_p05:.2f} |",
            f"| Placebo 95th percentile | {placebo_p95:.2f} |",
            f"| p-value (two-sided) | {pct_more_extreme:.4f} |",
            f"| Conclusion | {conclusion} |",
        ]
        (TABLE_DIR / "table23_placebo.md").write_text("\n".join(md))

        # Placebo figure. matplotlib is imported lazily with the non-interactive
        # Agg backend; the figures directory may not exist yet.
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        FIG_DIR.mkdir(parents=True, exist_ok=True)
        fig_path = FIG_DIR / "fig3_placebo.png"
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.hist(placebo_arr, bins=40, color='steelblue', alpha=0.7, edgecolor='white')
        ax.axvline(true_z1_change, color='red', linewidth=2, linestyle='--',
                   label=f'True CCA-nc effect ({true_z1_change:.2f})')
        ax.set_xlabel(f"Z₁ change (drop {n_to_drop} countries)")
        ax.set_ylabel("Frequency")
        ax.set_title(f"Placebo Distribution ({n_draws} random {n_to_drop}-country drops)")
        ax.legend()
        fig.tight_layout()
        fig.savefig(fig_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"  Saved figure: {fig_path}")

# Final status: where the tables went, then the completion marker.
print("\nAll Phase 7 tables saved to " + str(TABLE_DIR))
print("Phase 7 complete.")
